"""
Tests for collection deletion functionality.

Tests the CollectionDeletionService methods:
- delete_collection
- delete_collection_index_only
"""

import pytest
import uuid

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from local_deep_research.database.models.library import (
    Document,
    DocumentChunk,
    DocumentCollection,
    Collection,
    CollectionFolder,
    SourceType,
    RAGIndex,
    EmbeddingProvider,
)
from local_deep_research.database.models.base import Base


@pytest.fixture
def test_engine():
    """Create an in-memory SQLite database for testing."""
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    return engine


@pytest.fixture
def test_session(test_engine):
    """Create a session for testing."""
    Session = sessionmaker(bind=test_engine)
    session = Session()
    yield session
    session.close()


@pytest.fixture
def source_type(test_session):
    """Create a source type for testing."""
    st = SourceType(
        id=str(uuid.uuid4()),
        name="user_upload",
        display_name="User Upload",
    )
    test_session.add(st)
    test_session.commit()
    return st


@pytest.fixture
def test_collection(test_session):
    """Create a test collection with embedding settings."""
    collection = Collection(
        id=str(uuid.uuid4()),
        name="Test Collection",
        description="A test collection",
        embedding_model="all-MiniLM-L6-v2",
        embedding_model_type=EmbeddingProvider.SENTENCE_TRANSFORMERS,
        embedding_dimension=384,
        chunk_size=500,
        chunk_overlap=50,
    )
    test_session.add(collection)
    test_session.commit()
    return collection


@pytest.fixture
def test_documents(test_session, source_type, test_collection):
    """Create multiple test documents in a collection."""
    docs = []
    for i in range(3):
        doc_id = str(uuid.uuid4())
        doc = Document(
            id=doc_id,
            source_type_id=source_type.id,
            document_hash=f"hash_{doc_id[:8]}",
            file_size=1024 * (i + 1),
            file_type="pdf",
            title=f"Test Document {i + 1}",
            filename=f"test{i + 1}.pdf",
            text_content=f"Content of document {i + 1}",
        )
        test_session.add(doc)
        test_session.flush()

        # Add to collection
        doc_collection = DocumentCollection(
            document_id=doc_id,
            collection_id=test_collection.id,
            indexed=True,
            chunk_count=5,
        )
        test_session.add(doc_collection)
        docs.append(doc)

    test_session.commit()
    return docs


@pytest.fixture
def test_chunks(test_session, test_collection, test_documents):
    """Create chunks for the test documents."""
    collection_name = f"collection_{test_collection.id}"
    chunks = []

    for doc in test_documents:
        for i in range(2):
            chunk = DocumentChunk(
                chunk_hash=f"chunk_{doc.id[:8]}_{i}_{uuid.uuid4().hex[:8]}",
                source_type="document",
                source_id=doc.id,
                collection_name=collection_name,
                chunk_text=f"Chunk {i} of {doc.title}",
                chunk_index=i,
                start_char=i * 100,
                end_char=(i + 1) * 100,
                word_count=10,
                embedding_id=str(uuid.uuid4()),
                embedding_model="all-MiniLM-L6-v2",
                embedding_model_type=EmbeddingProvider.SENTENCE_TRANSFORMERS,
            )
            test_session.add(chunk)
            chunks.append(chunk)

    test_session.commit()
    return chunks


@pytest.fixture
def test_rag_index(test_session, test_collection):
    """Create a RAG index for the collection."""
    collection_name = f"collection_{test_collection.id}"
    rag_index = RAGIndex(
        collection_name=collection_name,
        embedding_model="all-MiniLM-L6-v2",
        embedding_model_type=EmbeddingProvider.SENTENCE_TRANSFORMERS,
        embedding_dimension=384,
        index_path=f"/tmp/test_index_{test_collection.id}",
        index_hash=f"hash_{test_collection.id[:8]}",
        chunk_size=500,
        chunk_overlap=50,
        chunk_count=6,
        total_documents=3,
    )
    test_session.add(rag_index)
    test_session.commit()
    return rag_index


class TestCollectionDeletion:
    """Unit tests for collection deletion."""

    def test_delete_collection_preserves_documents(
        self, test_session, test_collection, test_documents
    ):
        """Documents remain after collection deletion."""
        collection_id = test_collection.id
        doc_ids = [d.id for d in test_documents]

        # Delete collection
        collection = test_session.query(Collection).get(collection_id)
        test_session.delete(collection)
        test_session.commit()

        # Verify documents still exist
        for doc_id in doc_ids:
            doc = test_session.query(Document).get(doc_id)
            assert doc is not None, f"Document {doc_id[:8]} should still exist"

    def test_delete_collection_removes_links(
        self, test_session, test_collection, test_documents
    ):
        """DocumentCollection links are deleted with collection."""
        collection_id = test_collection.id

        # Verify links exist
        links_before = (
            test_session.query(DocumentCollection)
            .filter_by(collection_id=collection_id)
            .count()
        )
        assert links_before == 3, "Should have 3 document links"

        # Delete collection
        collection = test_session.query(Collection).get(collection_id)
        test_session.delete(collection)
        test_session.commit()

        # Verify links are deleted (CASCADE)
        links_after = (
            test_session.query(DocumentCollection)
            .filter_by(collection_id=collection_id)
            .count()
        )
        assert links_after == 0, "Document links should be deleted"

    def test_delete_collection_removes_chunks(
        self, test_session, test_collection, test_documents, test_chunks
    ):
        """DocumentChunks for collection are deleted."""
        collection_id = test_collection.id
        collection_name = f"collection_{collection_id}"

        # Verify chunks exist
        chunks_before = (
            test_session.query(DocumentChunk)
            .filter_by(collection_name=collection_name)
            .count()
        )
        assert chunks_before == 6, "Should have 6 chunks (2 per document)"

        # Manually delete chunks (no FK, must be explicit)
        deleted = (
            test_session.query(DocumentChunk)
            .filter_by(collection_name=collection_name)
            .delete(synchronize_session=False)
        )
        assert deleted == 6

        # Delete collection
        collection = test_session.query(Collection).get(collection_id)
        test_session.delete(collection)
        test_session.commit()

        # Verify all cleaned up
        chunks_after = (
            test_session.query(DocumentChunk)
            .filter_by(collection_name=collection_name)
            .count()
        )
        assert chunks_after == 0

    def test_delete_collection_removes_rag_index(
        self, test_session, test_collection, test_rag_index
    ):
        """RAGIndex records are deleted with collection."""
        collection_id = test_collection.id
        collection_name = f"collection_{collection_id}"

        # Verify RAG index exists
        index_before = (
            test_session.query(RAGIndex)
            .filter_by(collection_name=collection_name)
            .first()
        )
        assert index_before is not None, "RAG index should exist"

        # Manually delete RAG index (should be done by service)
        test_session.delete(index_before)

        # Delete collection
        collection = test_session.query(Collection).get(collection_id)
        test_session.delete(collection)
        test_session.commit()

        # Verify RAG index is deleted
        index_after = (
            test_session.query(RAGIndex)
            .filter_by(collection_name=collection_name)
            .first()
        )
        assert index_after is None


class TestCollectionIndexDeletion:
    """Tests for deleting only the collection index."""

    def test_delete_index_keeps_collection(
        self, test_session, test_collection, test_chunks, test_rag_index
    ):
        """Collection remains after index deletion."""
        collection_id = test_collection.id
        collection_name = f"collection_{collection_id}"

        # Delete chunks
        test_session.query(DocumentChunk).filter_by(
            collection_name=collection_name
        ).delete(synchronize_session=False)

        # Delete RAG index
        test_session.delete(test_rag_index)
        test_session.commit()

        # Verify collection still exists
        collection = test_session.query(Collection).get(collection_id)
        assert collection is not None, "Collection should still exist"

    def test_delete_index_resets_document_status(
        self, test_session, test_collection, test_documents, test_chunks
    ):
        """DocumentCollection indexed status is reset."""
        collection_id = test_collection.id
        collection_name = f"collection_{collection_id}"

        # Delete chunks
        test_session.query(DocumentChunk).filter_by(
            collection_name=collection_name
        ).delete(synchronize_session=False)

        # Reset document indexed status
        updated = (
            test_session.query(DocumentCollection)
            .filter_by(collection_id=collection_id)
            .update({"indexed": False, "chunk_count": 0})
        )
        assert updated == 3, "Should update 3 document links"

        test_session.commit()

        # Verify status is reset
        links = (
            test_session.query(DocumentCollection)
            .filter_by(collection_id=collection_id)
            .all()
        )
        for link in links:
            assert link.indexed is False, "Should be marked as not indexed"
            assert link.chunk_count == 0, "Chunk count should be 0"


class TestCollectionFolderCascade:
    """Tests for CollectionFolder cascade deletion."""

    def test_delete_collection_removes_folders(
        self, test_session, test_collection
    ):
        """CollectionFolders are deleted with collection."""
        collection_id = test_collection.id

        # Add a folder to the collection
        folder = CollectionFolder(
            collection_id=collection_id,
            folder_path="/tmp/test_folder",
            recursive=True,
        )
        test_session.add(folder)
        test_session.commit()

        # Verify folder exists
        folder_before = (
            test_session.query(CollectionFolder)
            .filter_by(collection_id=collection_id)
            .first()
        )
        assert folder_before is not None, "Folder should exist"

        # Delete collection
        collection = test_session.query(Collection).get(collection_id)
        test_session.delete(collection)
        test_session.commit()

        # Verify folder is deleted (CASCADE)
        folder_after = (
            test_session.query(CollectionFolder)
            .filter_by(collection_id=collection_id)
            .first()
        )
        assert folder_after is None, "Folder should be deleted"


class TestOrphanedDocumentDeletion:
    """Tests for orphaned document deletion during collection deletion."""

    def test_delete_collection_deletes_orphaned_documents(
        self, test_session, test_collection, test_documents, source_type
    ):
        """Documents only in the deleted collection are deleted."""
        collection_id = test_collection.id
        doc_ids = [d.id for d in test_documents]

        # Verify documents exist and are only in one collection
        for doc_id in doc_ids:
            link_count = (
                test_session.query(DocumentCollection)
                .filter_by(document_id=doc_id)
                .count()
            )
            assert link_count == 1, (
                "Each doc should be in exactly one collection"
            )

        # Manually delete links first
        test_session.query(DocumentCollection).filter_by(
            collection_id=collection_id
        ).delete(synchronize_session=False)

        # Check for orphaned documents and delete them
        orphaned_deleted = 0
        for doc_id in doc_ids:
            remaining = (
                test_session.query(DocumentCollection)
                .filter_by(document_id=doc_id)
                .count()
            )
            if remaining == 0:
                # Delete document
                doc = test_session.query(Document).get(doc_id)
                if doc:
                    test_session.delete(doc)
                    orphaned_deleted += 1

        # Delete collection
        collection = test_session.query(Collection).get(collection_id)
        test_session.delete(collection)
        test_session.commit()

        # Verify orphaned documents were deleted
        assert orphaned_deleted == 3, (
            "All 3 orphaned documents should be deleted"
        )

        for doc_id in doc_ids:
            doc = test_session.query(Document).get(doc_id)
            assert doc is None, (
                f"Orphaned document {doc_id[:8]} should be deleted"
            )

    def test_delete_collection_preserves_shared_documents(
        self, test_session, test_collection, test_documents, source_type
    ):
        """Documents in other collections are preserved."""
        collection_id = test_collection.id
        doc_ids = [d.id for d in test_documents]

        # Create a second collection and link the first document to it
        second_collection = Collection(
            id=str(uuid.uuid4()),
            name="Second Collection",
        )
        test_session.add(second_collection)
        test_session.flush()

        # Link first document to second collection
        shared_doc_id = doc_ids[0]
        link = DocumentCollection(
            document_id=shared_doc_id,
            collection_id=second_collection.id,
            indexed=False,
            chunk_count=0,
        )
        test_session.add(link)
        test_session.commit()

        # Delete links for first collection
        test_session.query(DocumentCollection).filter_by(
            collection_id=collection_id
        ).delete(synchronize_session=False)

        # Check orphaned docs
        preserved_count = 0
        deleted_count = 0
        for doc_id in doc_ids:
            remaining = (
                test_session.query(DocumentCollection)
                .filter_by(document_id=doc_id)
                .count()
            )
            if remaining == 0:
                doc = test_session.query(Document).get(doc_id)
                if doc:
                    test_session.delete(doc)
                    deleted_count += 1
            else:
                preserved_count += 1

        # Delete first collection
        collection = test_session.query(Collection).get(collection_id)
        test_session.delete(collection)
        test_session.commit()

        # Verify shared document was preserved
        shared_doc = test_session.query(Document).get(shared_doc_id)
        assert shared_doc is not None, "Shared document should be preserved"
        assert preserved_count == 1, "One document should be preserved"
        assert deleted_count == 2, "Two orphaned documents should be deleted"
